import torch

def softmax(X, dim=1):
    """Return the softmax of ``X`` computed along dimension ``dim``.

    Every slice of ``X`` taken along ``dim`` is exponentiated and divided by
    its own sum, so the entries of each slice are non-negative and sum to 1.

    Args:
        X: Tensor of scores to normalise.
        dim: Dimension to normalise over. Defaults to 1 (row-wise for a
            2-D matrix), which matches the original hard-coded behaviour.

    Returns:
        A tensor with the same shape as ``X``.
    """
    if X.numel() > 0:
        # softmax is unchanged by subtracting a constant from a slice, but
        # shifting by the slice maximum keeps every exponent <= 0 so exp()
        # cannot overflow to inf (which would yield inf / inf = nan).
        # Skipped for empty tensors: a max-reduction over a zero-length
        # dimension raises, and there is nothing to normalise anyway.
        X = X - X.amax(dim=dim, keepdim=True)
    X_exp = X.exp()
    partition = X_exp.sum(dim=dim, keepdim=True)
    # Broadcasting: the size-1 `dim` of `partition` stretches to match X_exp.
    return X_exp / partition

# Demo: draw a random 2x5 score matrix, turn each row into a probability
# distribution with softmax, then show the per-row totals (each ~1.0).
X = torch.rand(2, 5)
print(X)
print("-----------------------------")

x_prob = softmax(X)
print(x_prob)
print("--------------------------------")

# Print the probabilities alongside their row sums as a sanity check.
print(x_prob, x_prob.sum(1))